import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.main import main, parse_args
from base_grid_hpo import BaseGridHPO


class EOTHPO(BaseGridHPO):
    """Grid hyperparameter search that launches one experiment per grid point.

    Hyperparameters are added through ``register_hyperparam`` (inherited from
    ``BaseGridHPO``); ``select_hyperparams`` then runs every setting yielded
    by ``grid()``.
    """

    def __init__(self, base_args):
        """Remember *base_args*, the argv-style list every run starts from."""
        super().__init__()
        self.base_args = base_args

    def select_hyperparams(self):
        """Call ``run_exp`` once for each setting produced by ``self.grid()``."""
        for grid_point in self.grid():
            run_exp(grid_point, self.base_args)


def run_exp(setting, base_args):
    """Run one experiment with *setting* applied on top of *base_args*.

    ``sys.argv`` is temporarily replaced by a *copy* of *base_args* extended
    with a ``name value`` pair for every entry of *setting*. The result is
    parsed with ``parse_args()`` and handed to ``main()``. The original
    ``sys.argv`` is always restored afterwards, even if parsing or the run
    fails.

    Args:
        setting: Mapping from a command-line flag (e.g. ``'--lr'``) to the
            value to use for this grid point. Values are passed through ``str``.
        base_args: argv-style list shared by every run. It is not modified.

    Exceptions raised by ``main()`` are printed and suppressed so that one
    failing configuration does not abort the rest of the sweep. Exceptions from
    ``parse_args()`` (including ``SystemExit``) still propagate.
    """
    saved_argv = sys.argv
    # Copy instead of aliasing. Appending to base_args directly would leak
    # this setting's flags into every later run, and base_args may even be
    # the live sys.argv object itself.
    run_argv = list(base_args)
    for name in setting:
        run_argv.append(name)
        run_argv.append(str(setting[name]))
    sys.argv = run_argv
    try:
        args = parse_args()
        try:
            main(args)
        except Exception as e:
            # Deliberate best-effort: report and continue with the next setting.
            print('An exception occurred: {}'.format(e))
    finally:
        sys.argv = saved_argv


def ER(eot_hpo):
    """Register the ER-specific search space on *eot_hpo*.

    Only ``--buffer_size`` is registered, with the single value 5120.
    """
    buffer_sizes = [5120]
    eot_hpo.register_hyperparam('--buffer_size', None, None, buffer_sizes)


def DERpp(eot_hpo):
    """Register the DER++-specific search space on *eot_hpo*.

    ``--alpha`` and ``--beta`` are registered in that order. Each gets its own
    list of candidate values ``[0.2, 0.5, 1.0]``.
    """
    # NOTE(review): a '--buffer_size' sweep over [5120] was commented out here;
    # confirm whether leaving it disabled for DER++ is intentional.
    for flag in ('--alpha', '--beta'):
        # Build a fresh list per flag so the two registrations never share one object.
        eot_hpo.register_hyperparam(flag, None, None, [0.2, 0.5, 1.0])


def ESMER(eot_hpo):
    """Register the ESMER-specific search space on *eot_hpo*.

    Only ``--loss_margin`` is registered, with candidates 1.5, 1.2 and 1.0.
    """
    margins = [1.5, 1.2, 1.0]
    eot_hpo.register_hyperparam('--loss_margin', None, None, margins)


def icarl(eot_hpo):
    """Register the iCaRL-specific search space on *eot_hpo*.

    Only ``--buffer_size`` is registered, with the single value 5120.
    """
    candidate_sizes = [5120]
    eot_hpo.register_hyperparam('--buffer_size', None, None, candidate_sizes)


if __name__ == '__main__':
    # Flags forced on for every run of the sweep.
    sys.argv.extend(['--validation', '1', '--nowand', '1'])
    eot_hpo = EOTHPO(sys.argv)

    # Learning-rate grid registered for every model.
    learning_rates = [0.2, 0.15, 0.1, 0.075, 0.05, 0.03, 0.01, 0.0075, 0.005, 0.0025]
    eot_hpo.register_hyperparam('--lr', None, None, learning_rates)
    # Currently disabled sweeps, kept for reference:
    #   '--optim_wd':  [0.0, 0.00001]
    #   '--optim_mom': [0.0, 0.99]
    #   '--n_epochs':  [20, 50, 100]

    # Model-specific search spaces. A registrar is applied when its key appears
    # as a standalone token in sys.argv; entries are checked in this order.
    model_registrars = {
        # 'er': ER,
        'derpp': DERpp,
        'esmer': ESMER,
        # 'icarl': icarl,
    }
    for model_name, register in model_registrars.items():
        if model_name in sys.argv:
            register(eot_hpo)

    eot_hpo.select_hyperparams()



